# scripts/step1_tstar.py
import argparse
import os
import json
import numpy as np
import pandas as pd

from src.present_act.gates import ThetaLadder, KappaLadder, StructuralGates, CRA
from src.present_act.lints import Lints
from src.present_act.engine import PresentActEngine, RunManifest
from src.present_act.scenes import (
    make_optics_scene,
    make_optics_roi,
    profile_from_roi,
    place_sources_from_s,
)
from src.utils.analysis import simple_regression, plot_series
from scripts._util import ensure_out, write_md, load_cfg


def measure_reach(scene, theta_mins, seeds, shots=64, min_commits=8):
    """
    Reach-threshold metric.

    For every (seed, theta_min) pair a fresh engine is built with the Theta
    ladder [theta_min, theta_min + 1, theta_min + 2] and ``shots`` proposals
    are fired, each from the left or the right source (chosen at random).
    Every accepted candidate counts as a commit and is deposited on the ROI
    midline row of the screen.

    Args:
        scene: optics scene; ``W``, ``H`` and ``roi_bbox`` are read here.
        theta_mins: iterable of lowest Theta bins to sweep.
        seeds: iterable of seeds. Each seed is passed to the run manifest and
            also seeds the left/right source choice, so a run is reproducible.
        shots: number of proposals per (seed, theta_min) run.
        min_commits: a run is "stabilized" when commits >= min_commits.

    Returns:
        DataFrame with one row per (seed, theta_min), in seed-major order,
        and columns ``theta_min``, ``seed``, ``commits``, ``stabilized``
        (0/1). The columns are present even when there are no runs.
    """
    rows = []
    # Place sources deeper so small Theta cannot reach midline.
    (xL, y_src), (xR, _y_src) = place_sources_from_s(
        scene, s=max(2, scene.W // 3), y_row=max(1, int(0.60 * scene.H))
    )
    # Both sources use the deep row y_src returned for the left source.
    left_src, right_src = (xL, y_src), (xR, y_src)
    # NOTE(review): assumes roi_bbox is (x0, y0, x1, y1), so this is the
    # vertical midpoint of the ROI — confirm against the scene builder.
    roi_mid = (scene.roi_bbox[1] + scene.roi_bbox[3]) // 2

    for seed in seeds:
        for tmin in theta_mins:
            man = RunManifest(
                theta=ThetaLadder([tmin, tmin + 1, tmin + 2]),
                kappa=KappaLadder([2]),
                structural=StructuralGates(),
                cra=CRA(True),
                lints=Lints(),
                seed=seed,
            )
            screen = make_optics_roi(scene)
            eng = PresentActEngine(scene, man)
            # Draw the left/right choice from a generator seeded by `seed`
            # instead of the unseeded global NumPy RNG, so runs are
            # reproducible. Re-seeding per theta_min gives every threshold the
            # same source sequence (common random numbers).
            rng = np.random.default_rng(seed)

            commits = 0
            for _ in range(shots):
                src = left_src if rng.random() < 0.5 else right_src
                cands = eng.propose_candidates([src], screen)
                acc, _res = eng.accept(cands)
                if acc is None:
                    continue
                x, _y = acc
                screen[roi_mid, x] += 1
                commits += 1

            rows.append(
                {
                    "theta_min": tmin,
                    "seed": seed,
                    "commits": int(commits),
                    "stabilized": int(commits >= min_commits),
                }
            )
    return pd.DataFrame(rows, columns=["theta_min", "seed", "commits", "stabilized"])


def main(cfg):
    """
    Run the STEP 1 T* reach-threshold sweep and calibrate ``c``.

    For each L_out in ``cfg["containers"]["L_out_list"]`` a scene is built,
    and ``measure_reach`` sweeps ``cfg["theta"]["min_bins"]`` over
    ``cfg["common"]["seeds"]``. omega_star for that L_out is the smallest
    theta_min whose ``stabilized`` rate, averaged over seeds, is >= 0.5. It is
    None if no theta_min reaches that rate.

    A line is fitted to omega_star vs L_out, and ``c = 1 / slope`` when the
    slope is positive. ``c`` is NaN otherwise, and also NaN when fewer than
    two L_out values have an omega_star.

    Outputs:
      - ``tstar_L{L}.csv`` per L_out and the summary ``tstar_steps.csv``, in
        the directory returned by ``ensure_out("out", "step1")``;
      - ``omega_star_vs_Lout.png`` in the same directory (best effort, only
        when a fit was made);
      - ``out/calibration_time_hinge.json`` and ``out/RESULTS_TSTAR.md``.

    Args:
        cfg: parsed config mapping (as returned by ``load_cfg``).
    """
    out_step = ensure_out("out", "step1")
    # Config
    Louts = cfg["containers"]["L_out_list"]
    theta_mins = cfg["theta"]["min_bins"]
    seeds = cfg["common"]["seeds"]
    shots = int(cfg["common"]["shots"])
    min_commits = int(cfg["tstar"]["min_commits_threshold"])

    step_rows = []
    for L in Louts:
        scene = make_optics_scene(L, w=int(cfg["scene"]["w_inner_px"]))
        # No mask here — depth alone should create the step.
        df = measure_reach(scene, theta_mins, seeds, shots=shots, min_commits=min_commits)
        df["L_out"] = L
        df.to_csv(os.path.join(out_step, f"tstar_L{L}.csv"), index=False)

        # Fraction of seeds that stabilized at each theta_min; omega_star is
        # the first theta_min where at least half of the seeds did.
        agg = df.groupby("theta_min")["stabilized"].mean()
        reached = agg[agg >= 0.5].index
        omega_star = int(reached.min()) if len(reached) else None
        step_rows.append({"L_out": L, "omega_star": omega_star})

    step_df = pd.DataFrame(step_rows)
    step_df.to_csv(os.path.join(out_step, "tstar_steps.csv"), index=False)

    # Fit omega_star vs L_out → slope ≈ 1/c
    c = float("nan")
    valid = step_df.dropna()
    if len(valid) >= 2:
        x = valid["L_out"].values.astype(float)
        y = valid["omega_star"].values.astype(float)
        reg = simple_regression(x, y)
        slope = float(reg["slope"])
        # Only a positive slope yields a calibrated c; otherwise it stays NaN.
        if slope > 0:
            c = 1.0 / slope
        try:
            plot_series(
                x, y,
                xlab="L_out (pixels)",
                ylab="omega_star (first reachable theta_min)",
                title="T* step vs L_out",
                out_png=os.path.join(out_step, "omega_star_vs_Lout.png"),
            )
        except Exception as exc:
            # Plotting is best-effort: report the failure instead of hiding
            # it, but keep the calibration result.
            print(f"warning: could not write omega_star_vs_Lout.png: {exc!r}")

    # NOTE(review): when c is NaN, json.dump writes the non-standard token
    # `NaN`. Python's json reads it back, but strict JSON parsers reject it.
    calib_path = os.path.join("out", "calibration_time_hinge.json")
    with open(calib_path, "w", encoding="utf-8") as f:
        json.dump({"c": c, "omega_star_by_L": step_rows}, f, indent=2)

    write_md(
        os.path.join("out", "RESULTS_TSTAR.md"),
        "# STEP 1 — T* reach-threshold\n"
        f"shots={shots}, min_commits_threshold={min_commits}\n"
        f"calibrated c={c}\n\n" + step_df.to_string(index=False) + "\n",
    )
    print("T* step complete. c =", c)


if __name__ == "__main__":
    # Command-line entry point: read the config file given by --config
    # (defaults to the low-compute profile) and run the T* sweep with it.
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", default="configs/low_compute.yaml")
    cli_args = parser.parse_args()
    main(load_cfg(cli_args.config))
